#!/bin/bash

# ---- Recipe configuration -------------------------------------------------
# Location of the THCHS-30 corpus; handed to the stage-0 formatting helper.
thchs30_path="/ceph/home/zhangy20/datasets/data_thchs30"

# Directory forwarded to main.py as --embedding_save_dir in stage 2 and
# wiped at the start of stage 3.
embedding_save_dir="npy"

# Pull in the environment setup. Later stages read $OPENASV_ROOT, which is
# presumably exported by path.sh (not visible from here -- confirm).
source ./path.sh

# Exactly one stage runs per invocation; edit this number to choose it:
#   0 = format data dir, 1 = build lists, 2 = run with training options,
#   3 = run with --test.
stage=3
printf 'stage %s\n' "$stage"

# Stage 0: lay the corpus out under data/wav_files/thchs30.
# Per the original author's note the helper formats the directory structure
# using soft links; the helper script itself is not shown here.
if [ "$stage" -eq 0 ]; then
    # mkdir -p is a no-op when the directory already exists, so no separate
    # existence test is needed. Stop if the directory cannot be created.
    mkdir -p data/wav_files || exit 1

    # Quote the corpus path so it always reaches the helper as exactly one
    # argument, and abort with a message if it is unset or empty instead of
    # silently shifting the helper's arguments.
    ./local/thch30_datadir_format.sh \
        "${thchs30_path:?thchs30_path is not set}" \
        data/wav_files/thchs30
fi

# Stage 1: build the data lists and the trial list read by stages 2 and 3.
# Produces data/dev_list and data/trials.lst; data/tmp_list is a scratch file.
# Each step aborts the script on failure so that a broken list is never fed
# into the next step.
# NOTE(review): the local/*.py helpers are launched with `python` while the
# OpenASV scripts use `python3` -- left unchanged; confirm both resolve to
# the same interpreter.
if [ "$stage" -eq 1 ]; then
    # Fail early with a clear message rather than running /scripts/... when
    # the environment has not been set up.
    : "${OPENASV_ROOT:?OPENASV_ROOT is not set}"

    # Scratch list of the wav files under the test split.
    python3 "$OPENASV_ROOT/scripts/build_datalist.py" \
        --extension wav \
        --dataset_dir data/wav_files/thchs30/test \
        --data_list_path data/tmp_list || exit 1

    # Original author's note: randomly select 1000 utts from thchs30 test as
    # enrollment. The helper takes no arguments here; presumably it reads
    # data/tmp_list -- verify against local/make_enroll_test.py.
    python local/make_enroll_test.py
    enroll_status=$?
    # Remove the scratch list whether or not the helper succeeded, then
    # propagate a failure.
    rm -rf data/tmp_list
    [ "$enroll_status" -eq 0 ] || exit "$enroll_status"

    # List of the wav files under the train split, written to data/dev_list.
    python3 "$OPENASV_ROOT/scripts/build_datalist.py" \
        --extension wav \
        --dataset_dir data/wav_files/thchs30/train \
        --data_list_path data/dev_list || exit 1

    # Drop any previous trial list before regenerating it.
    rm -rf data/trials.lst
    python local/make_thch30_trials.py \
        --thchs30_test_dir data/wav_files/thchs30/test \
        --output_trial_path data/trials.lst || exit 1
fi


# Stage 2: launch main.py with the training options below (no --test flag).
# Reads data/dev_list, data/musan_list, data/rirs_list and data/trials.lst.
if [ "$stage" -eq 2 ]; then
    # Fail early with a clear message rather than running /main.py when the
    # environment has not been set up.
    : "${OPENASV_ROOT:?OPENASV_ROOT is not set}"

    ckpt_path=ckpt.pt

    # Two devices are exposed through CUDA_VISIBLE_DEVICES and --gpus is set
    # to 2; keep the two in step when changing either. Variable expansions
    # are quoted so each value is passed as a single argument.
    CUDA_VISIBLE_DEVICES=0,1 python3 -W ignore "$OPENASV_ROOT/main.py" \
        --batch_size 400 \
        --num_workers 20 \
        --save_top_k 50 \
        --train_list_path data/dev_list \
        --musan_list_path data/musan_list \
        --rirs_list_path data/rirs_list \
        --embedding_save_dir "$embedding_save_dir" \
        --loss_type softmax \
        --max_epochs 200 \
        --max_frames 220 --min_frames 200 \
        --learning_rate 0.001 \
        --distributed_backend dp \
        --max_seg_per_spk 500 \
        --trials_path data/trials.lst \
        --eval_interval 5 \
        --checkpoint_path "$ckpt_path" \
        --gpus 2
fi


# Stage 3: run main.py in --test mode against the checkpoint in ckpt.pt,
# after clearing the embedding directory; lightning_logs is removed afterwards.
if [ "$stage" -eq 3 ]; then
    # Fail early (before deleting anything) when the environment has not
    # been set up, rather than running /main.py.
    : "${OPENASV_ROOT:?OPENASV_ROOT is not set}"

    ckpt_path=ckpt.pt

    # ${var:?} aborts if the variable is unset or empty, so this can never
    # expand to `rm -rf /`; `--` stops a value starting with '-' from being
    # parsed as an option.
    rm -rf -- "${embedding_save_dir:?embedding_save_dir is not set}/"

    # --trials_path points at data/trials.lst, the file stage 1 writes and
    # stage 2 reads (previously a bare trials.lst in the working directory).
    CUDA_VISIBLE_DEVICES=3 python3 -W ignore "$OPENASV_ROOT/main.py" \
        --batch_size 10 \
        --num_workers 64 \
        --train_list_path data/dev_list \
        --test_list_path data/dev_list \
        --trials_path data/trials.lst \
        --gpus 1 \
        --max_frames 201 --min_frames 200 \
        --checkpoint_path "$ckpt_path" \
        --test

    rm -rf lightning_logs
fi



